{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Workflow for a ReAct Agent\n",
    "\n",
    "This notebook walks through setting up a `Workflow` to construct a ReAct agent from (mostly) scratch.\n",
    "\n",
    "React calling agents work by prompting an LLM to either invoke tools/functions, or return a final response.\n",
    "\n",
    "Our workflow will be stateful with memory, and will be able to call the LLM to select tools and process incoming user messages."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -U llama-index"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "os.environ[\"OPENAI_API_KEY\"] = \"sk-proj-...\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [Optional] Set up observability with Llamatrace\n",
    "\n",
    "Set up tracing to visualize each step in the workflow."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install \"llama-index-core>=0.10.43\" \"openinference-instrumentation-llama-index>=2\" \"opentelemetry-proto>=1.12.0\" opentelemetry-exporter-otlp opentelemetry-sdk"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from opentelemetry.sdk import trace as trace_sdk\n",
    "from opentelemetry.sdk.trace.export import SimpleSpanProcessor\n",
    "from opentelemetry.exporter.otlp.proto.http.trace_exporter import (\n",
    "    OTLPSpanExporter as HTTPSpanExporter,\n",
    ")\n",
    "from openinference.instrumentation.llama_index import LlamaIndexInstrumentor\n",
    "\n",
    "\n",
    "# Add Phoenix API Key for tracing\n",
    "PHOENIX_API_KEY = \"<YOUR-PHOENIX-API-KEY>\"\n",
    "os.environ[\"OTEL_EXPORTER_OTLP_HEADERS\"] = f\"api_key={PHOENIX_API_KEY}\"\n",
    "\n",
    "# Add Phoenix\n",
    "span_phoenix_processor = SimpleSpanProcessor(\n",
    "    HTTPSpanExporter(endpoint=\"https://app.phoenix.arize.com/v1/traces\")\n",
    ")\n",
    "\n",
    "# Add them to the tracer\n",
    "tracer_provider = trace_sdk.TracerProvider()\n",
    "tracer_provider.add_span_processor(span_processor=span_phoenix_processor)\n",
    "\n",
    "# Instrument the application\n",
    "LlamaIndexInstrumentor().instrument(tracer_provider=tracer_provider)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Since workflows are async first, this all runs fine in a notebook. If you were running in your own code, you would want to use `asyncio.run()` to start an async event loop if one isn't already running.\n",
    "\n",
    "```python\n",
    "async def main():\n",
    "    <async code>\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    import asyncio\n",
    "    asyncio.run(main())\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Designing the Workflow\n",
    "\n",
    "An agent consists of several steps\n",
    "1. Handling the latest incoming user message, including adding to memory and preparing the chat history\n",
    "2. Using the chat history and tools to construct a ReAct prompt\n",
    "3. Calling the llm with the react prompt, and parsing out function/tool calls\n",
    "4. If no tool calls, we can return\n",
    "5. If there are tool calls, we need to execute them, and then loop back for a fresh ReAct prompt using the latest tool calls\n",
    "\n",
    "### The Workflow Events\n",
    "\n",
    "To handle these steps, we need to define a few events:\n",
    "1. An event to handle new messages and prepare the chat history\n",
    "2. An event to stream the LLM response\n",
    "3. An event to prompt the LLM with the react prompt\n",
    "4. An event to trigger tool calls, if any\n",
    "5. An event to handle the results of tool calls, if any\n",
    "\n",
    "The other steps will use the built-in `StartEvent` and `StopEvent` events.\n",
    "\n",
    "In addition to events, we will also use the global context to store the current react reasoning!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from llama_index.core.llms import ChatMessage\n",
    "from llama_index.core.tools import ToolSelection, ToolOutput\n",
    "from llama_index.core.workflow import Event\n",
    "\n",
    "\n",
    "class PrepEvent(Event):\n",
    "    pass\n",
    "\n",
    "\n",
    "class InputEvent(Event):\n",
    "    input: list[ChatMessage]\n",
    "\n",
    "\n",
    "class StreamEvent(Event):\n",
    "    delta: str\n",
    "\n",
    "\n",
    "class ToolCallEvent(Event):\n",
    "    tool_calls: list[ToolSelection]\n",
    "\n",
    "\n",
    "class FunctionOutputEvent(Event):\n",
    "    output: ToolOutput"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The Workflow Itself\n",
    "\n",
    "With our events defined, we can construct our workflow and steps. \n",
    "\n",
    "Note that the workflow automatically validates itself using type annotations, so the type annotations on our steps are very helpful!"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Any, List\n",
    "\n",
    "from llama_index.core.agent.react import ReActChatFormatter, ReActOutputParser\n",
    "from llama_index.core.agent.react.types import (\n",
    "    ActionReasoningStep,\n",
    "    ObservationReasoningStep,\n",
    ")\n",
    "from llama_index.core.llms.llm import LLM\n",
    "from llama_index.core.memory import ChatMemoryBuffer\n",
    "from llama_index.core.tools.types import BaseTool\n",
    "from llama_index.core.workflow import (\n",
    "    Context,\n",
    "    Workflow,\n",
    "    StartEvent,\n",
    "    StopEvent,\n",
    "    step,\n",
    ")\n",
    "from llama_index.llms.openai import OpenAI\n",
    "\n",
    "\n",
    "class ReActAgent(Workflow):\n",
    "    def __init__(\n",
    "        self,\n",
    "        *args: Any,\n",
    "        llm: LLM | None = None,\n",
    "        tools: list[BaseTool] | None = None,\n",
    "        extra_context: str | None = None,\n",
    "        **kwargs: Any,\n",
    "    ) -> None:\n",
    "        super().__init__(*args, **kwargs)\n",
    "        self.tools = tools or []\n",
    "        self.llm = llm or OpenAI()\n",
    "        self.formatter = ReActChatFormatter.from_defaults(\n",
    "            context=extra_context or \"\"\n",
    "        )\n",
    "        self.output_parser = ReActOutputParser()\n",
    "\n",
    "    @step\n",
    "    async def new_user_msg(self, ctx: Context, ev: StartEvent) -> PrepEvent:\n",
    "        # clear sources\n",
    "        await ctx.store.set(\"sources\", [])\n",
    "\n",
    "        # init memory if needed\n",
    "        memory = await ctx.store.get(\"memory\", default=None)\n",
    "        if not memory:\n",
    "            memory = ChatMemoryBuffer.from_defaults(llm=self.llm)\n",
    "\n",
    "        # get user input\n",
    "        user_input = ev.input\n",
    "        user_msg = ChatMessage(role=\"user\", content=user_input)\n",
    "        memory.put(user_msg)\n",
    "\n",
    "        # clear current reasoning\n",
    "        await ctx.store.set(\"current_reasoning\", [])\n",
    "\n",
    "        # set memory\n",
    "        await ctx.store.set(\"memory\", memory)\n",
    "\n",
    "        return PrepEvent()\n",
    "\n",
    "    @step\n",
    "    async def prepare_chat_history(\n",
    "        self, ctx: Context, ev: PrepEvent\n",
    "    ) -> InputEvent:\n",
    "        # get chat history\n",
    "        memory = await ctx.store.get(\"memory\")\n",
    "        chat_history = memory.get()\n",
    "        current_reasoning = await ctx.store.get(\n",
    "            \"current_reasoning\", default=[]\n",
    "        )\n",
    "\n",
    "        # format the prompt with react instructions\n",
    "        llm_input = self.formatter.format(\n",
    "            self.tools, chat_history, current_reasoning=current_reasoning\n",
    "        )\n",
    "        return InputEvent(input=llm_input)\n",
    "\n",
    "    @step\n",
    "    async def handle_llm_input(\n",
    "        self, ctx: Context, ev: InputEvent\n",
    "    ) -> ToolCallEvent | StopEvent:\n",
    "        chat_history = ev.input\n",
    "        current_reasoning = await ctx.store.get(\n",
    "            \"current_reasoning\", default=[]\n",
    "        )\n",
    "        memory = await ctx.store.get(\"memory\")\n",
    "\n",
    "        response_gen = await self.llm.astream_chat(chat_history)\n",
    "        async for response in response_gen:\n",
    "            ctx.write_event_to_stream(StreamEvent(delta=response.delta or \"\"))\n",
    "\n",
    "        try:\n",
    "            reasoning_step = self.output_parser.parse(response.message.content)\n",
    "            current_reasoning.append(reasoning_step)\n",
    "\n",
    "            if reasoning_step.is_done:\n",
    "                memory.put(\n",
    "                    ChatMessage(\n",
    "                        role=\"assistant\", content=reasoning_step.response\n",
    "                    )\n",
    "                )\n",
    "                await ctx.store.set(\"memory\", memory)\n",
    "                await ctx.store.set(\"current_reasoning\", current_reasoning)\n",
    "\n",
    "                sources = await ctx.store.get(\"sources\", default=[])\n",
    "\n",
    "                return StopEvent(\n",
    "                    result={\n",
    "                        \"response\": reasoning_step.response,\n",
    "                        \"sources\": [sources],\n",
    "                        \"reasoning\": current_reasoning,\n",
    "                    }\n",
    "                )\n",
    "            elif isinstance(reasoning_step, ActionReasoningStep):\n",
    "                tool_name = reasoning_step.action\n",
    "                tool_args = reasoning_step.action_input\n",
    "                return ToolCallEvent(\n",
    "                    tool_calls=[\n",
    "                        ToolSelection(\n",
    "                            tool_id=\"fake\",\n",
    "                            tool_name=tool_name,\n",
    "                            tool_kwargs=tool_args,\n",
    "                        )\n",
    "                    ]\n",
    "                )\n",
    "        except Exception as e:\n",
    "            current_reasoning.append(\n",
    "                ObservationReasoningStep(\n",
    "                    observation=f\"There was an error in parsing my reasoning: {e}\"\n",
    "                )\n",
    "            )\n",
    "            await ctx.store.set(\"current_reasoning\", current_reasoning)\n",
    "\n",
    "        # if no tool calls or final response, iterate again\n",
    "        return PrepEvent()\n",
    "\n",
    "    @step\n",
    "    async def handle_tool_calls(\n",
    "        self, ctx: Context, ev: ToolCallEvent\n",
    "    ) -> PrepEvent:\n",
    "        tool_calls = ev.tool_calls\n",
    "        tools_by_name = {tool.metadata.get_name(): tool for tool in self.tools}\n",
    "        current_reasoning = await ctx.store.get(\n",
    "            \"current_reasoning\", default=[]\n",
    "        )\n",
    "        sources = await ctx.store.get(\"sources\", default=[])\n",
    "\n",
    "        # call tools -- safely!\n",
    "        for tool_call in tool_calls:\n",
    "            tool = tools_by_name.get(tool_call.tool_name)\n",
    "            if not tool:\n",
    "                current_reasoning.append(\n",
    "                    ObservationReasoningStep(\n",
    "                        observation=f\"Tool {tool_call.tool_name} does not exist\"\n",
    "                    )\n",
    "                )\n",
    "                continue\n",
    "\n",
    "            try:\n",
    "                tool_output = tool(**tool_call.tool_kwargs)\n",
    "                sources.append(tool_output)\n",
    "                current_reasoning.append(\n",
    "                    ObservationReasoningStep(observation=tool_output.content)\n",
    "                )\n",
    "            except Exception as e:\n",
    "                current_reasoning.append(\n",
    "                    ObservationReasoningStep(\n",
    "                        observation=f\"Error calling tool {tool.metadata.get_name()}: {e}\"\n",
    "                    )\n",
    "                )\n",
    "\n",
    "        # save new state in context\n",
    "        await ctx.store.set(\"sources\", sources)\n",
    "        await ctx.store.set(\"current_reasoning\", current_reasoning)\n",
    "\n",
    "        # prep the next iteraiton\n",
    "        return PrepEvent()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "And thats it! Let's explore the workflow we wrote a bit.\n",
    "\n",
    "`new_user_msg()`:\n",
    "Adds the user message to memory, and clears the global context to keep track of a fresh string of reasoning.\n",
    "\n",
    "`prepare_chat_history()`:\n",
    "Prepares the react prompt, using the chat history, tools, and current reasoning (if any)\n",
    "\n",
    "`handle_llm_input()`:\n",
    "Prompts the LLM with our react prompt, and uses some utility functions to parse the output. If there are no tool calls, we can stop and emit a `StopEvent`. Otherwise, we emit a `ToolCallEvent` to handle tool calls. Lastly, if there are no tool calls, and no final response, we simply loop again.\n",
    "\n",
    "`handle_tool_calls()`:\n",
    "Safely calls tools with error handling, adding the tool outputs to the current reasoning. Then, by emitting a `PrepEvent`, we loop around for another round of ReAct prompting and parsing."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Run the Workflow!\n",
    "\n",
    "**NOTE:** With loops, we need to be mindful of runtime. Here, we set a timeout of 120s."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running step new_user_msg\n",
      "Step new_user_msg produced event PrepEvent\n",
      "Running step prepare_chat_history\n",
      "Step prepare_chat_history produced event InputEvent\n",
      "Running step handle_llm_input\n",
      "Step handle_llm_input produced event StopEvent\n"
     ]
    }
   ],
   "source": [
    "from llama_index.core.tools import FunctionTool\n",
    "from llama_index.llms.openai import OpenAI\n",
    "\n",
    "\n",
    "def add(x: int, y: int) -> int:\n",
    "    \"\"\"Useful function to add two numbers.\"\"\"\n",
    "    return x + y\n",
    "\n",
    "\n",
    "def multiply(x: int, y: int) -> int:\n",
    "    \"\"\"Useful function to multiply two numbers.\"\"\"\n",
    "    return x * y\n",
    "\n",
    "\n",
    "tools = [\n",
    "    FunctionTool.from_defaults(add),\n",
    "    FunctionTool.from_defaults(multiply),\n",
    "]\n",
    "\n",
    "agent = ReActAgent(\n",
    "    llm=OpenAI(model=\"gpt-4o\"), tools=tools, timeout=120, verbose=True\n",
    ")\n",
    "\n",
    "ret = await agent.run(input=\"Hello!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Hello! How can I assist you today?\n"
     ]
    }
   ],
   "source": [
    "print(ret[\"response\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running step new_user_msg\n",
      "Step new_user_msg produced event PrepEvent\n",
      "Running step prepare_chat_history\n",
      "Step prepare_chat_history produced event InputEvent\n",
      "Running step handle_llm_input\n",
      "Step handle_llm_input produced event ToolCallEvent\n",
      "Running step handle_tool_calls\n",
      "Step handle_tool_calls produced event PrepEvent\n",
      "Running step prepare_chat_history\n",
      "Step prepare_chat_history produced event InputEvent\n",
      "Running step handle_llm_input\n",
      "Step handle_llm_input produced event ToolCallEvent\n",
      "Running step handle_tool_calls\n",
      "Step handle_tool_calls produced event PrepEvent\n",
      "Running step prepare_chat_history\n",
      "Step prepare_chat_history produced event InputEvent\n",
      "Running step handle_llm_input\n",
      "Step handle_llm_input produced event StopEvent\n"
     ]
    }
   ],
   "source": [
    "ret = await agent.run(input=\"What is (2123 + 2321) * 312?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The result of (2123 + 2321) * 312 is 1,386,528.\n"
     ]
    }
   ],
   "source": [
    "print(ret[\"response\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chat History\n",
    "\n",
    "By default, the workflow is creating a fresh `Context` for each run. This means that the chat history is not preserved between runs. However, we can pass our own `Context` to the workflow to preserve chat history."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running step new_user_msg\n",
      "Step new_user_msg produced event PrepEvent\n",
      "Running step prepare_chat_history\n",
      "Step prepare_chat_history produced event InputEvent\n",
      "Running step handle_llm_input\n",
      "Step handle_llm_input produced event StopEvent\n",
      "Hello, Logan! How can I assist you today?\n",
      "Running step new_user_msg\n",
      "Step new_user_msg produced event PrepEvent\n",
      "Running step prepare_chat_history\n",
      "Step prepare_chat_history produced event InputEvent\n",
      "Running step handle_llm_input\n",
      "Step handle_llm_input produced event StopEvent\n",
      "Your name is Logan.\n"
     ]
    }
   ],
   "source": [
    "from llama_index.core.workflow import Context\n",
    "\n",
    "ctx = Context(agent)\n",
    "\n",
    "ret = await agent.run(input=\"Hello! My name is Logan\", ctx=ctx)\n",
    "print(ret[\"response\"])\n",
    "\n",
    "ret = await agent.run(input=\"What is my name?\", ctx=ctx)\n",
    "print(ret[\"response\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Streaming\n",
    "\n",
    "We can also access the streaming response from the LLM, using the `handler` object returned from the `.run()` method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Thought: The current language of the user is: English. I cannot use a tool to help me answer the question.\n",
      "Answer: Why don't scientists trust atoms? Because they make up everything!"
     ]
    }
   ],
   "source": [
    "agent = ReActAgent(\n",
    "    llm=OpenAI(model=\"gpt-4o\"), tools=tools, timeout=120, verbose=False\n",
    ")\n",
    "\n",
    "handler = agent.run(input=\"Hello! Tell me a joke.\")\n",
    "\n",
    "async for event in handler.stream_events():\n",
    "    if isinstance(event, StreamEvent):\n",
    "        print(event.delta, end=\"\", flush=True)\n",
    "\n",
    "response = await handler\n",
    "# print(response)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama-index-cDlKpkFt-py3.11",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
